[NPU]: update the native KLDivLoss implementation for comparison (e.g. test_jsd.py) #1032
Conversation
Test results on NPU before: error (screenshot not captured). Test results on NPU after: (screenshot not captured).
Tcc0403
left a comment
Great catch, you can report the issue to the torch NPU team.
Overall LGTM, just a nit change for documentation.
test/transformers/test_jsd.py
Outdated
```python
set_seed(42)


class CustomKLDivLoss(torch.nn.Module):
```
nit: Add a docstring to explain why we need a custom KLDivLoss. Since it's an NPU-exclusive issue, I'd also name it with NPU.
```diff
-class CustomKLDivLoss(torch.nn.Module):
+class NPUKLDivLoss(torch.nn.Module):
+    """
+    A custom KLDivLoss for NPU.
+    On NPU devices, torch.nn.KLDivLoss does not compute gradients with respect to the target.
+    This leads to incorrect gradient computation when the target depends on the input,
+    such as in JSD or reverse KLDiv.
+    See https://github.com/linkedin/Liger-Kernel/issues/1021 for more details.
+    """
```
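(A minimal sketch of the failure mode the docstring describes; the shapes and variable names below are illustrative assumptions, and the NPU behavior is as reported in #1021, not reverified here.)

```python
import torch

# JSD-style setup: both distributions come from trainable tensors, so the
# "target" passed to kl_div also needs a gradient.
q_logits = torch.randn(4, 8, requires_grad=True)
p_logits = torch.randn(4, 8, requires_grad=True)

log_q = q_logits.log_softmax(dim=-1)  # input: log-probabilities
p = p_logits.softmax(dim=-1)          # target: probabilities

loss = torch.nn.functional.kl_div(log_q, p, reduction="batchmean")
loss.backward()

print(q_logits.grad)  # populated on all backends
print(p_logits.grad)  # expected to be populated as well; per #1021 the native
                      # NPU operator reportedly does not propagate gradients
                      # through the target
```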
Thanks. I have reported the issue to the Torch NPU team and updated the documentation as requested.
Tcc0403
left a comment
Thank you!
Oops, my bad, there's a formatting issue. You can also install pre-commit hooks to fix it automatically; see:
Liger-Kernel/docs/contributing.md
Lines 30 to 37 in 559e9a1
3. **Install pre-commit hooks using [`prek`](https://prek.j178.dev/), a `pre-commit` alternative built in rust**
   ```
   prek install
   ```
   Run pre-commit check without committing (`-a` is equivalent to `--all-files`)
   ```
   prek run -a
   ```
My bad, I should have noticed that. Thanks for the tip!
Summary
This PR modifies the NPU test reference for KLDivLoss. Since the native NPU KLDivLoss operator does not support gradients w.r.t. the target (#1021), it caused failures in test_jsd.py (where input and target are swapped when beta != 0).
To resolve this, I replaced the native operator usage with a custom implementation using basic math operations. This allows correct gradient computation for the target and aligns the x1.grad results with the Triton kernel implementation.
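As a rough illustration of what such a basic-ops implementation can look like (a minimal sketch; the class name, the `eps` guard, and the reduction handling are assumptions, not necessarily the exact code in this PR):

```python
import torch


class CustomKLDivLoss(torch.nn.Module):
    """KLDivLoss built from elementary ops so autograd also reaches the target."""

    def __init__(self, reduction: str = "batchmean", log_target: bool = False, eps: float = 1e-10):
        super().__init__()
        self.reduction = reduction
        self.log_target = log_target
        self.eps = eps

    def forward(self, log_q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        # log_q: log-probabilities of the prediction; p: target probabilities
        # (or log-probabilities when log_target=True), following torch.nn.KLDivLoss semantics.
        if self.log_target:
            loss = torch.exp(p) * (p - log_q)
        else:
            loss = p * (torch.log(p + self.eps) - log_q)
        if self.reduction == "batchmean":
            return loss.sum() / log_q.size(0)
        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "mean":
            return loss.mean()
        return loss  # reduction == "none"
```

Because every operation here is an ordinary differentiable tensor op, gradients flow to both arguments, which is what allows x1.grad to line up with the Triton kernel results even when input and target are swapped.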
Testing Done
I tested test_jsd and test_fused_linear_jsd with the following commands, and all cases passed:
pytest -v test/transformers/test_jsd.py
pytest -v test/transformers/test_fused_linear_jsd.py
Hardware Type: Ascend NPU 910B3
- run `make test` to ensure correctness
- run `make checkstyle` to ensure code style
- run `make test-convergence` to ensure convergence